﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Atmk.WaterMeter.MIS.Commons.Enums;
using Atmk.WaterMeter.MIS.Commons.Interfaces;
using Atmk.WaterMeter.MIS.Commons.Utils;
using Atmk.WaterMeter.MIS.Datas;
using Atmk.WaterMeter.MIS.Entities.Common;
using Atmk.WaterMeter.MIS.Entities.Enums;
using Atmk.WaterMeter.MIS.Entities.Models;
using Microsoft.EntityFrameworkCore;

namespace Atmk.WaterMeter.MIS.Services.DataAccess.Base
{
    /// <inheritdoc />
    /// <summary>
    /// Base class for the data-access (repository) layer.
    /// </summary>
    /// <remarks>
    /// <c>Add</c>, <c>AddRang&lt;T&gt;</c>, <c>Delete</c>, <c>Remove</c> (except the parameterless overload) and
    /// <c>Update</c> each open a short-lived context via <c>ContextBuilder.Build()</c> and save immediately.
    /// <c>AddRang</c>, <c>AddRang2</c>, <c>Remove&lt;T&gt;()</c>, <c>SaveChanges</c> and every query method use the
    /// shared <see cref="Context"/> field instead.
    /// NOTE(review): the context created by the parameterless constructor is never disposed, because this class
    /// does not implement <see cref="IDisposable"/>.
    /// </remarks>
    public class BaseRepository: IRepository
    {
        /// <summary>
        /// Context shared by the query methods and by the write methods that do not build their own context.
        /// </summary>
        protected readonly Context Context;

        /// <summary>
        /// Creates a repository whose shared context is obtained from <c>ContextBuilder.Build()</c>.
        /// </summary>
        public BaseRepository()
        {
            Context = ContextBuilder.Build();
        }

        /// <summary>
        /// Creates a repository over an externally supplied context.
        /// </summary>
        /// <param name="context">Context stored as the shared <see cref="Context"/>.</param>
        public BaseRepository(Context context)
        {
            Context = context;
        }

        /// <summary>
        /// Persists the pending changes tracked by the shared <see cref="Context"/>.
        /// </summary>
        /// <returns>The value returned by <c>Context.SaveChanges()</c>.</returns>
        public virtual int SaveChanges()
        {
            return Context.SaveChanges();
        }

        /// <inheritdoc />
        public virtual int Add(IRecord entity)
        {
            using (var context = ContextBuilder.Build())
            {
                // Add through the context rather than Set<IRecord>(). IRecord is not a mapped entity type, so EF Core
                // cannot build a DbSet for it. DbContext.Add resolves the entity type from the runtime type instead.
                context.Add(entity);
                return context.SaveChanges();
            }
        }

        /// <summary>
        /// Inserts a single entity using a short-lived context and saves immediately.
        /// </summary>
        /// <param name="entity">Entity to insert.</param>
        /// <returns>The value returned by <c>SaveChanges()</c>.</returns>
        public virtual int Add<T>(T entity)
        {
            using (var context = ContextBuilder.Build())
            {
                context.Add(entity);
                return context.SaveChanges();
            }
        }

        /// <inheritdoc />
        public virtual int AddRang(params IRecord[] entities)
        {
            if (entities != null && entities.Length > 0)
            {
                Context.AddRange(entities);
            }
            // Always save, as before: the shared context may also hold other pending changes.
            return Context.SaveChanges();
        }

        /// <summary>
        /// Inserts one entity or a sequence of entities using a short-lived context and saves immediately.
        /// </summary>
        /// <param name="entity">A single entity, or a sequence of entities (e.g. an array or list).</param>
        /// <returns>The value returned by <c>SaveChanges()</c>.</returns>
        public virtual int AddRang<T>(T entity)
        {
            using (var context = ContextBuilder.Build())
            {
                context.AddRange(AsEntityList(entity));
                return context.SaveChanges();
            }
        }

        /// <summary>
        /// Inserts a sequence of records into the shared <see cref="Context"/> and saves it.
        /// </summary>
        /// <param name="entities">Records to insert; <c>null</c> adds nothing.</param>
        /// <returns>The value returned by <c>Context.SaveChanges()</c>.</returns>
        public virtual int AddRang2(IEnumerable<IRecord> entities)
        {
            if (entities != null)
            {
                Context.AddRange(entities);
            }
            return Context.SaveChanges();
        }

        /// <inheritdoc />
        /// <summary>
        /// Soft delete: marks each record as <see cref="RecordStateEnum.Deleted"/> and saves it as an update.
        /// The rows are not removed from the database.
        /// </summary>
        /// <param name="entities">Records to mark as deleted; <c>null</c> or empty does nothing.</param>
        /// <returns>The value returned by <c>SaveChanges()</c>.</returns>
        public virtual int Delete(params IRecord[] entities)
        {
            using (var context = ContextBuilder.Build())
            {
                if (entities != null && entities.Length > 0)
                {
                    foreach (var item in entities)
                    {
                        item.RecordState = RecordStateEnum.Deleted;
                        // NOTE(review): the author passed to ModifiedLog is hard-coded as "user".
                        item.ModifiedLog("user", "state change");
                    }
                    context.UpdateRange(entities);
                }
                return context.SaveChanges();
            }
        }

        /// <inheritdoc />
        /// <summary>
        /// Hard delete: removes the records from the database.
        /// </summary>
        /// <param name="entities">Records to remove; <c>null</c> or empty does nothing.</param>
        /// <returns>The value returned by <c>SaveChanges()</c>.</returns>
        public virtual int Remove(params IRecord[] entities)
        {
            using (var context = ContextBuilder.Build())
            {
                if (entities != null && entities.Length > 0)
                {
                    context.RemoveRange(entities);
                }
                return context.SaveChanges();
            }
        }

        /// <summary>
        /// Hard-deletes one entity or a sequence of entities using a short-lived context.
        /// </summary>
        /// <param name="entities">A single entity, or a sequence of entities (e.g. an array or list).</param>
        /// <returns>The value returned by <c>SaveChanges()</c>.</returns>
        public virtual int Remove<T>(T entities)
        {
            using (var context = ContextBuilder.Build())
            {
                context.RemoveRange(AsEntityList(entities));
                return context.SaveChanges();
            }
        }

        /// <summary>
        /// Hard-deletes every row of <typeparamref name="T"/> through the shared <see cref="Context"/>.
        /// All rows are loaded into memory before being removed.
        /// </summary>
        /// <returns>The value returned by <c>Context.SaveChanges()</c>.</returns>
        public virtual int Remove<T>() where T : BaseRecord
        {
            Context.Set<T>().RemoveRange(Context.Set<T>().ToArray());
            return Context.SaveChanges();
        }

        /// <inheritdoc />
        /// <summary>
        /// Saves the given records as modified, using a short-lived context.
        /// </summary>
        /// <param name="entities">Records to update; <c>null</c> or empty does nothing.</param>
        /// <returns>The value returned by <c>SaveChanges()</c>.</returns>
        public virtual int Update(params IRecord[] entities)
        {
            using (var context = ContextBuilder.Build())
            {
                if (entities != null && entities.Length > 0)
                {
                    context.UpdateRange(entities);
                }
                return context.SaveChanges();
            }
        }

        /// <summary>
        /// Saves one entity or a sequence of entities as modified, using a short-lived context.
        /// </summary>
        /// <param name="entities">A single entity, or a sequence of entities (e.g. an array or list).</param>
        /// <returns>The value returned by <c>SaveChanges()</c>.</returns>
        public virtual int Update<T>(T entities)
        {
            using (var context = ContextBuilder.Build())
            {
                context.UpdateRange(AsEntityList(entities));
                return context.SaveChanges();
            }
        }

        /// <summary>
        /// Returns every row of <typeparamref name="T"/>, tracked by the shared <see cref="Context"/>.
        /// </summary>
        /// <returns>A materialised list of all rows.</returns>
        public virtual IEnumerable<T> FindAll<T>()where T:class 
        {
            return Context.Set<T>().ToList();
        }

        /// <summary>
        /// Returns every row of <typeparamref name="T"/> without change tracking (read-only use).
        /// </summary>
        /// <returns>A materialised list of all rows.</returns>
        public virtual IEnumerable<T> FindAllAsNoTracking<T>() where T : class
        {
            return Context.Set<T>().AsNoTracking().ToList();
        }

        /// <inheritdoc />
        /// <summary>
        /// Finds a record by its Id.
        /// </summary>
        /// <exception cref="KeyNotFoundException">No record has the given Id.</exception>
        public virtual IRecord FindById(string id)
        {
            // NOTE(review): Set<IRecord>() requires IRecord to be an entity type in the model. EF Core cannot map
            // an interface, so this query presumably fails at runtime. Prefer FindById<T>(id).
            var result = Context.Set<IRecord>().FirstOrDefault(x => x.Id == id);
            if (result == null)
            {
                throw new KeyNotFoundException("查询不到数据");
            }
            return result;
        }

        /// <inheritdoc />
        /// <summary>
        /// Finds an entity of type <typeparamref name="T"/> by its Id.
        /// </summary>
        /// <exception cref="KeyNotFoundException">No entity has the given Id.</exception>
        public T FindById<T>(string id) where T : BaseRecord
        {
            var result = Context.Set<T>().FirstOrDefault(x => x.Id == id);
            if (result == null)
            {
                throw new KeyNotFoundException("查询不到数据");
            }
            return result;
        }

        /// <summary>
        /// Paged query over the rows of <typeparamref name="T"/> whose state is
        /// <see cref="RecordStateEnum.Normal"/>. Direct subclasses of <c>BaseEntity</c> are ordered by
        /// <paramref name="sortOrder"/>; other types are left unordered.
        /// </summary>
        /// <param name="sortOrder">"Name" to order by name; anything else orders by creation time.</param>
        /// <param name="currentFilter">Currently unused.</param>
        /// <param name="searchString">Currently unused.</param>
        /// <param name="page">1-based page number; when <c>null</c>, all rows are returned on page 1.</param>
        /// <returns>The requested page, built without change tracking.</returns>
        public virtual async Task<PaginatedList<IRecord>> FindIndex<T>(string sortOrder,
            string currentFilter,
            string searchString,
            int? page) where T:class 
        {
            // Compare Type objects, not names. Name equality can match an unrelated type called BaseEntity,
            // and BaseType is null for interfaces.
            var query = typeof(T).BaseType == typeof(BaseEntity) ? EntityFind<T>(sortOrder) : RecordFind<T>();

            int pageSize;
            if (page.HasValue)
            {
                // NOTE(review): the page number doubles as the page size here, which looks unintended.
                // The original author noted that a default should eventually come from a configuration table.
                pageSize = page.Value;
            }
            else
            {
                // Everything on one page. Use at least 1 so an empty table never yields a 0 page size
                // (PaginatedList presumably divides by it - confirm).
                pageSize = Math.Max(1, await query.CountAsync());
            }
            return await PaginatedList<IRecord>.CreateAsync(query.AsNoTracking(), page ?? 1, pageSize);
        }

        /// <summary>
        /// Rows of <typeparamref name="T"/> (viewed as <see cref="IRecord"/>) whose state is Normal, unordered.
        /// </summary>
        private IQueryable<IRecord> RecordFind<T>() where T : class
        {
            var query =
                ((IQueryable<IRecord>) Context.Set<T>()).Where(s => s.RecordState == RecordStateEnum.Normal);
            return query;
        }

        /// <summary>
        /// Rows of <typeparamref name="T"/> (viewed as <see cref="IEntity"/>) whose state is Normal, ordered by
        /// Name when <paramref name="sortOrder"/> is "Name", otherwise by CreateTime.
        /// </summary>
        private IQueryable<IRecord> EntityFind<T>(string sortOrder) where T : class
        {
            var query =
                ((IQueryable<IEntity>) Context.Set<T>()).Where(s => s.RecordState == RecordStateEnum.Normal);
            switch (sortOrder)
            {
                case "Name":
                    query = query.OrderBy(s => s.Name);
                    break;
                default:
                    query = query.OrderBy(s => s.CreateTime);
                    break;
            }
            return query;
        }

        /// <summary>
        /// Normalises a generic argument into the entities it stands for. A value that is itself a sequence of
        /// reference-type objects (array, list, ...) is returned as-is; anything else is wrapped as one entity.
        /// This is needed because an unconstrained <typeparamref name="T"/> always binds to the
        /// <c>params object[]</c> overloads, which would otherwise treat a whole collection as a single entity.
        /// </summary>
        private static IEnumerable<object> AsEntityList<T>(T value)
        {
            object boxed = value;
            var many = boxed as IEnumerable<object>;
            return many ?? new[] { boxed };
        }
    }
}